{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IPEX model for text-generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load your IPEX model, simply replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class. Set `export=True` to load a PyTorch checkpoint, export the model via TorchScript, and apply IPEX optimizations: both operator-level optimization (operators are replaced with customized IPEX operators) and graph-level optimization (such as operator fusion) will be applied to your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "from optimum.intel.ipex import IPEXModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Framework not specified. Using pt to export the model.\n",
      "Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.\n",
      "/home/jiqingfe/frameworks.ai.pytorch.ipex-cpu/intel_extension_for_pytorch/frontend.py:462: UserWarning: Conv BatchNorm folding failed during the optimize process.\n",
      "  warnings.warn(\n",
      "/home/jiqingfe/frameworks.ai.pytorch.ipex-cpu/intel_extension_for_pytorch/frontend.py:469: UserWarning: Linear BatchNorm folding failed during the optimize process.\n",
      "  warnings.warn(\n",
      "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/transformers/modeling_utils.py:4193: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n",
      "  warnings.warn(\n",
      "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:801: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if batch_size <= 0:\n",
      "Passing the argument `library_name` to `get_supported_tasks_for_model_type` is required, but got library_name=None. Defaulting to `transformers`. An error will be raised in a future version of Optimum if `library_name` is not provided.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
      "access to the `model_dtype` attribute is deprecated and will be removed after v1.18.0, please use `_dtype` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet? Yes, you can.\n",
      "\n",
      "Yes, I can write Haikus in one tweet. I have no idea how to do that, but I'm sure\n"
     ]
    }
   ],
   "source": [
    "# One checkpoint name is shared by the model and its tokenizer.\n",
    "model_id = \"gpt2\"\n",
    "\n",
    "# export=True loads the PyTorch checkpoint and exports it with the IPEX optimizations applied.\n",
    "model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# A single-prompt batch, tokenized into PyTorch tensors.\n",
    "prompts = [\"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?\"]\n",
    "encoded_prompts = tokenizer(prompts, return_tensors=\"pt\")\n",
    "\n",
    "# Beam search without sampling: 4 beams in one group, at most 32 new tokens.\n",
    "generation_kwargs = {\n",
    "    \"max_new_tokens\": 32,\n",
    "    \"do_sample\": False,\n",
    "    \"num_beams\": 4,\n",
    "    \"num_beam_groups\": 1,\n",
    "    \"no_repeat_ngram_size\": 2,\n",
    "    \"use_cache\": True,\n",
    "}\n",
    "\n",
    "generated_ids = model.generate(**encoded_prompts, **generation_kwargs)\n",
    "\n",
    "# Decode the whole batch and show the text generated for the only prompt.\n",
    "decoded_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "print(decoded_texts[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ipex",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
